#!/usr/bin/env python
# -*- coding: utf-8 -*-

__author__ = 'James Iter'
__date__ = '15/12/20'
__contact__ = 'james.iter.cn@gmail.com'
__copyright__ = '(c) 2015 by James Iter.'

import mysql.connector
from mysql.connector import errorcode
from random import choice
import json
import copy
import time, datetime
import jimit as ji

from initialize import app, ts, logger, object_db_map, objects_model, cold_db_conn_map_list, \
    cold_db_conn_map_list_reversed, regex_sql_str
from models import Rules, Utils
import states


class RecodeOperate(object):
    """Validated CRUD operations for one object type (one table).

    ``name`` selects the object type. Field rules come from
    ``objects_model['rules'][name]``, filterable indexes from
    ``objects_model['index'][name]`` and the database layout (``rw_mode``,
    connection-pool groups ``db_group_s``, ``domain``) from
    ``object_db_map[name]``.

    NOTE(review): all statements are built by string concatenation and values
    are protected only by ``get_fit_statement``'s quoting. Parameterized
    queries (``cursor.execute(stmt, params)``) would be safer.
    """

    def __init__(self, name=None, kv=None, filter_str=None):
        # name:       object / table name, a key of objects_model['rules'].
        # kv:         field -> value mapping used by create / update / delete.
        # filter_str: comma-separated DSL items used by get.
        self.name = name
        self.kv = kv
        self.filter_str = filter_str
        # Position in objects_model['index'][name] used by the recursive index
        # check in previewing_filter_str; it is reset to 0 after every check.
        self.i = 0

    def previewing(self):
        """Check with ji.Check.previewing that ``self.name`` is a known object type."""
        args_rules = [
            (basestring, 'name', objects_model['rules'].keys())
        ]
        ji.Check.previewing(args_rules, self.__dict__)

    def previewing_kv(self):
        """Check ``self.kv`` against every rule defined for ``self.name``."""
        # TODO: optimize so the rule list is not recomputed on every call.
        args_rules = objects_model['rules'][self.name].values()
        ji.Check.previewing(args_rules, self.kv)

    def previewing_has_kv(self):
        """Check only the fields present in ``self.kv``; ``id`` is never checked here."""
        object_rules = objects_model['rules'][self.name]
        args_rules = [value for key, value in object_rules.items() if key in self.kv and key != 'id']
        ji.Check.previewing(args_rules, self.kv)

    def previewing_filter_str(self):
        """Validate ``self.filter_str`` and make sure every filtered field is indexed.

        Each comma-separated item is checked with ``Rules.DSL``. Items starting
        with ``__`` (limit / order by / range, see get_select_sql_stmt) are not
        field filters and are skipped by the index check.

        Raises:
            ji.PreviewingError: state 41250 when a filtered field is not covered
                by ``objects_model['index'][self.name]``.
        """
        args_rules = [
            (basestring, 'filter_str', (8, 1000))
        ]
        ji.Check.previewing(args_rules, self.__dict__)

        filter_keys = []
        for dsl in self.filter_str.split(','):
            # Rules.DSL.value is checked against locals() -- presumably under
            # the key 'dsl', so keep this loop variable's name.
            args_rules = [
                Rules.DSL.value
            ]
            ji.Check.previewing(args_rules, locals())
            if dsl[0:2] == '__':
                continue
            ele_s = dsl.split('__')
            filter_keys.append(ele_s[0])

        def indexed(field=None, origin_field=None):
            # True when ``field`` is an index entry by itself, or when some
            # '<index entry>__<origin_field>' string is an index entry whose
            # '__'-separated parts are all filtered keys. self.i walks the
            # index list across the recursive calls.
            _args_rules = [
                (basestring, 'field'),
                (basestring, 'origin_field')
            ]
            ji.Check.previewing(_args_rules, locals())

            if field in objects_model['index'][self.name]:
                if self.i > 0:
                    # Composite candidate: every part must be filtered on too.
                    for _field in field.split('__'):
                        if _field not in filter_keys:
                            # '_' is used as a placeholder to move on to the
                            # next candidate (assumes '_' is not itself an
                            # index entry).
                            return indexed('_', origin_field)

                self.i = 0
                return True
            elif self.i < objects_model['index'][self.name].__len__():
                field = '__'.join([objects_model['index'][self.name][self.i], origin_field])
                self.i += 1
                return indexed(field, origin_field)
            else:
                self.i = 0
                return False

        for k in filter_keys:
            if not indexed(k, k):
                ret = dict()
                ret['state'] = ji.Common.exchange_state(41250)
                ret['state']['sub']['zh-cn'] = ''.join([ret['state']['sub']['zh-cn'], ': ', k])
                logger.warning(json.dumps(ret))
                raise ji.PreviewingError(json.dumps(ret))

    def create(self):
        """Insert ``self.kv`` as a new row of ``self.name``.

        ``self.kv`` must satisfy every rule of the object type. With BW_AR the
        row goes to the single write DB chosen by ``id % len(write DBs)``; with
        SW_SR it is written to every DB of the write group (one after another,
        without a distributed transaction).

        Raises:
            ji.PreviewingError: state 50050 on a MySQL error.
            OSError: unknown rw_mode.
        """
        self.previewing()
        self.previewing_kv()

        object_db = object_db_map[self.name]

        if object_db['rw_mode'] == states.RWMode.BW_AR.value:
            # Choose the write DB by id modulo the number of write DBs, so an
            # update only has to run against the DB that holds the row.
            entry = self.kv['id'] % object_db['db_group_s'][0].__len__()
            target_pools = [object_db['db_group_s'][0][entry]]
        elif object_db['rw_mode'] == states.RWMode.SW_SR.value:
            target_pools = object_db['db_group_s'][0]
        else:
            raise OSError('What?')

        object_rules = objects_model['rules'][self.name]
        keys = []
        values = []
        for k in object_rules.keys():
            keys.append(k)
            values.append(str(RecodeOperate.get_fit_statement(object_rules, k, self.kv[k])))
        sql_stmt = ''.join(['INSERT INTO ', self.name, ' (', ', '.join(keys), ') VALUES (', ', '.join(values), ');'])

        for cnxpool in target_pools:
            RecodeOperate._execute_and_commit(cnxpool.get_connection(), sql_stmt)

    def update(self):
        """Update the fields present in ``self.kv`` on the row whose id is ``self.kv['id']``.

        Only fields listed in the object's rules are written; ``id`` is never
        SET. Cold-DB data cannot be updated or deleted.

        SW_SR: the UPDATE runs on every DB of the write group.
        BW_AR: the row is looked up in the write DB chosen by
        ``id % len(write DBs)`` (the same choice create() makes). If it is
        found there, only that DB is updated; otherwise every write DB is.

        Side effect: in BW_AR mode ``self.filter_str`` is overwritten with
        ``id__eq__<id>``.

        Raises:
            ji.PreviewingError: state 50050 on a MySQL error.
            OSError: unknown rw_mode.
        """
        self.previewing()
        self.previewing_has_kv()

        object_rules = objects_model['rules'][self.name]
        assignments = []
        for k in object_rules.keys():
            if k not in self.kv or k == 'id':
                continue
            assignments.append(' = '.join([k, str(RecodeOperate.get_fit_statement(object_rules, k, self.kv[k]))]))

        # None means there is no field to SET: nothing is executed and no
        # connection is checked out (the old code leaked one in that case).
        sql_stmt = None
        if assignments.__len__() > 0:
            # NOTE(review): id is not validated by previewing_has_kv and is
            # concatenated as-is -- confirm callers always pass an integer.
            sql_stmt = ''.join(['UPDATE ', self.name, ' SET ', ', '.join(assignments),
                                ' WHERE id = ', str(self.kv['id'])])

        def _update(_cnxpool):
            if sql_stmt is not None:
                RecodeOperate._execute_and_commit(_cnxpool.get_connection(), sql_stmt)

        object_db = object_db_map[self.name]

        if object_db['rw_mode'] == states.RWMode.SW_SR.value:
            for cnxpool in object_db['db_group_s'][0]:
                _update(cnxpool)

        elif object_db['rw_mode'] == states.RWMode.BW_AR.value:
            object_bw_list = object_db['db_group_s'][0]
            entry_pool = object_bw_list[self.kv['id'] % object_bw_list.__len__()]
            self.filter_str = 'id__eq__' + str(self.kv['id'])
            select_stmt = RecodeOperate.get_select_sql_stmt(self.name, self.filter_str)[0]
            rows = RecodeOperate.sql_stmt_fetchall(_cnx=entry_pool.get_connection(), _sql_stmt=select_stmt)

            if rows.__len__() > 0:
                # The row is where create() put it: update only that DB.
                # (Previously nothing was updated in this case.)
                _update(entry_pool)
            else:
                # Not found in the expected DB: update every write DB.
                for cnxpool in object_bw_list:
                    _update(cnxpool)

        else:
            raise OSError('What?')

    def get(self):
        """Return the rows of ``self.name`` matching ``self.filter_str``.

        Objects whose domain is not ``DBDomain.hot`` are read with select()
        only. For hot-domain objects the requested [begin_at, end_at] range is
        compared with the dump-cycle boundaries to decide whether to read the
        object's own DBs, the cold DBs, or both.

        Raises:
            ji.PreviewingError: invalid filter or MySQL error.
            ValueError: a required query parameter is None (see the helpers).
            OSError: unknown rw_mode.
        """
        self.previewing()
        self.previewing_filter_str()

        def select(_object_db=None, _sql_stmt=None, _order_key=None, _desc=None, _limit=None):
            # Query the object's own DB groups (group 0 is the write group,
            # the following groups are read from).
            if None in [_object_db, _sql_stmt]:
                raise ValueError(''.join([', '.join(['_object_db', '_sql_stmt']), ' must not None']))

            if _object_db['rw_mode'] == states.RWMode.SW_SR.value:
                # Any single pool of group 1 is queried.
                _cnxpool = choice(_object_db['db_group_s'][1])
                return RecodeOperate.sql_stmt_fetchall(_cnx=_cnxpool.get_connection(), _sql_stmt=_sql_stmt)

            elif _object_db['rw_mode'] == states.RWMode.BW_AR.value:
                # TODO: default _order_key to the object's primary key when it is None.
                if None in [_order_key, _desc, _limit]:
                    raise ValueError(''.join([', '.join(['_order_key', '_desc', '_limit']), ' must not None']))

                # TODO: untested.
                # Query every pool of one randomly chosen read group, then
                # re-apply ordering and limit over the merged rows in Python.
                object_bw_ar_list = choice(_object_db['db_group_s'][1:])
                _rows = []
                for _cnxpool in object_bw_ar_list:
                    _rows.extend(RecodeOperate.sql_stmt_fetchall(_cnx=_cnxpool.get_connection(),
                                                                 _sql_stmt=_sql_stmt))

                _rows_sorted = sorted(_rows, key=lambda item: item[_order_key], reverse=_desc)
                return _rows_sorted[:_limit]

            raise OSError('What?')

        def select_with_cold(_sql_stmt=None, _order_key=None, _desc=None, _limit=None, _begin_at=None, _end_at=None):
            # Query the cold DBs whose sequence lies in [_begin_at, _end_at]
            # (after conversion below), stopping once _limit rows are collected.
            # NOTE(review): _order_key is required but never used here, so a
            # get() without an order-by item fails whenever cold DBs are read.
            if None in [_sql_stmt, _order_key, _desc, _limit, _begin_at, _end_at]:
                raise ValueError(''.join([', '.join(['_object_db', '_sql_stmt', '_order_key', '_desc', '_limit',
                                                     '_begin_at', '_end_at']), ' must not None']))

            _begin_at = int(Utils.ts_to_date(_begin_at) / 100)
            _end_at = int(Utils.ts_to_date(_end_at) / 100)

            _rows = list()
            # Descending queries walk the reversed cold-DB list.
            cur_db_conn_map_list = cold_db_conn_map_list_reversed if _desc else cold_db_conn_map_list

            for item in cur_db_conn_map_list:
                if _end_at < item['sequence'] or item['sequence'] < _begin_at:
                    continue

                _cnxpool = choice(item['cnx_s'])
                _rows.extend(RecodeOperate.sql_stmt_fetchall(_cnx=_cnxpool.get_connection(), _sql_stmt=_sql_stmt))
                # Keep going even if this DB returned nothing: an intermediate
                # DB may simply hold no matching rows.
                if _rows.__len__() >= _limit:
                    # Slice with our own _limit (the old code read the
                    # enclosing ``limit`` by accident).
                    _rows = _rows[:_limit]
                    break

            return _rows

        sql_stmt, order_key, desc, limit, begin_at, end_at = RecodeOperate.get_select_sql_stmt(self.name,
                                                                                               self.filter_str)
        object_db = object_db_map[self.name]

        if object_db['domain'] != states.DBDomain.hot.value:
            return select(_object_db=object_db, _sql_stmt=sql_stmt, _order_key=order_key, _desc=desc,
                          _limit=limit)

        # NOTE(review): ``ts`` is imported from initialize; if it is captured
        # once at import time, "now" below is stale -- confirm.
        this_cycle_begin_ts = Utils.get_the_cycle_begin_ts(cycle_unit=app.config['DUMP_CYCLE'], offset=0)
        last_cycle_begin_ts = Utils.get_the_cycle_begin_ts(cycle_unit=app.config['DUMP_CYCLE'], offset=1)
        date_fly_time_end_ts = this_cycle_begin_ts + app.config['DATA_FLY_TIME']

        if (ts < date_fly_time_end_ts and last_cycle_begin_ts < begin_at) or \
                (ts > date_fly_time_end_ts and this_cycle_begin_ts < begin_at):
            # Cold DBs are not involved.
            return select(_object_db=object_db, _sql_stmt=sql_stmt, _order_key=order_key, _desc=desc,
                          _limit=limit)

        if (ts < date_fly_time_end_ts and end_at < last_cycle_begin_ts) or \
                (ts > date_fly_time_end_ts and end_at < this_cycle_begin_ts):
            # Only cold DBs are involved.
            return select_with_cold(_sql_stmt=sql_stmt, _order_key=order_key, _desc=desc, _limit=limit,
                                    _begin_at=begin_at, _end_at=end_at)

        # Hot and cold combined: descending reads hot first, then cold;
        # ascending reads cold first, then hot.
        if desc:
            rows = select(_object_db=object_db, _sql_stmt=sql_stmt, _order_key=order_key, _desc=desc,
                          _limit=limit)
        else:
            rows = select_with_cold(_sql_stmt=sql_stmt, _order_key=order_key, _desc=desc, _limit=limit,
                                    _begin_at=begin_at, _end_at=end_at)

        if rows.__len__() < limit:
            # Continue past the last id fetched so far. With no rows yet there
            # is nothing to continue from, so the query is reused unchanged
            # (the old ``id < 0`` bound made descending queries return nothing).
            ext_sql_stmt = sql_stmt
            if rows.__len__() > 0:
                ext_sql_stmt = RecodeOperate._add_id_bound(sql_stmt, self.name, '<' if desc else '>',
                                                           rows[-1]['id'])

            if desc:
                rows_extend = select_with_cold(_sql_stmt=ext_sql_stmt, _order_key=order_key, _desc=desc,
                                               _limit=limit, _begin_at=begin_at, _end_at=end_at)
            else:
                rows_extend = select(_object_db=object_db, _sql_stmt=ext_sql_stmt, _order_key=order_key,
                                     _desc=desc, _limit=limit)

            rows.extend(rows_extend)
            rows = rows[:limit]

        return rows

    def delete(self):
        """Delete the row whose id is ``self.kv['id']`` from every DB of the write group.

        Cold-DB data cannot be updated or deleted. Both SW_SR and BW_AR target
        the whole write group.

        Raises:
            ji.PreviewingError: state 50050 on a MySQL error.
            OSError: unknown rw_mode.
        """
        self.previewing()
        self.previewing_has_kv()

        object_db = object_db_map[self.name]

        if object_db['rw_mode'] not in [states.RWMode.SW_SR.value, states.RWMode.BW_AR.value]:
            raise OSError('What?')

        object_rules = objects_model['rules'][self.name]
        sql_stmt = ''.join(['DELETE FROM ', self.name, ' WHERE id = ',
                            str(RecodeOperate.get_fit_statement(object_rules, 'id', self.kv['id']))])

        for cnxpool in object_db['db_group_s'][0]:
            RecodeOperate._execute_and_commit(cnxpool.get_connection(), sql_stmt)

    @staticmethod
    def get_select_sql_stmt(_name=None, _filter_str=None):
        """Translate a DSL filter string into a SELECT statement on table ``_name``.

        Items not starting with ``__`` become AND-ed WHERE conditions. ``__``
        items are interpreted via Utils.dsl_to_sql as ``limit``, ``order by``
        or ``range``.

        Returns:
            (sql_stmt, order_key, desc, limit, begin_at, end_at) where
            order_key is the ``order by`` field (or None); desc is True when the
            ``order by`` item has a third element; limit defaults to
            app.config['DB_RESULT_LIMIT'] (then also appended to the SQL);
            begin_at / end_at default to the start of the current dump cycle
            and ``ts``.

        Raises:
            ValueError: ``_filter_str`` is None.
        """
        if _filter_str is None:
            raise ValueError('_filter_str must not None')

        sub_sql_stmt = []
        tail_sql_stmt = []
        order_key = None
        limit = None
        desc = False
        # Start timestamp of the current dump cycle.
        begin_at = Utils.get_the_cycle_begin_ts(cycle_unit=app.config['DUMP_CYCLE'], offset=0)
        # Current time.
        end_at = ts

        for v in _filter_str.split(','):
            v = v.strip(' ')
            sql_stmt_str, sql_stmt_list = Utils.dsl_to_sql(_name, v)
            if v[0:2] != '__':
                sub_sql_stmt.append(sql_stmt_str)
                continue

            if sql_stmt_list[0] == 'limit':
                limit = int(sql_stmt_list[1])
            elif sql_stmt_list[0] == 'order by':
                order_key = sql_stmt_list[1]
                if sql_stmt_list.__len__() == 3:
                    desc = True
            elif sql_stmt_list[0] == 'range':
                bounds = sql_stmt_list[1]
                if bounds.__len__() >= 2:
                    begin_at = int(bounds[0])
                    end_at = int(bounds[1])
                elif bounds.__len__() == 1:
                    # Only a start was given; end_at keeps its default.
                    # (The old ``>= 1`` test made this branch unreachable and
                    # raised IndexError on bounds[1].)
                    begin_at = int(bounds[0])
                else:
                    continue

                sql_stmt_str = ' AND '.join([' '.join([str(begin_at), ' <= ', app.config['TIME_LINE_FIELD']]),
                                             ' '.join([app.config['TIME_LINE_FIELD'], ' < ', str(end_at)])])
                sub_sql_stmt.append(sql_stmt_str)
                continue

            tail_sql_stmt.append(sql_stmt_str)

        sub_sql_stmt = ' AND '.join(sub_sql_stmt)
        # An empty condition list means no WHERE clause at all.
        if sub_sql_stmt.__len__() > 0:
            sub_sql_stmt = ''.join([' WHERE ', sub_sql_stmt])
        tail_sql_stmt = ' '.join(tail_sql_stmt)
        sql_stmt = ''.join(['SELECT * FROM ', _name, sub_sql_stmt, ' ', tail_sql_stmt])
        if limit is None:
            limit = app.config['DB_RESULT_LIMIT']
            sql_stmt = ''.join([sql_stmt, ' limit ', str(limit)])

        return sql_stmt, order_key, desc, limit, begin_at, end_at

    @staticmethod
    def _add_id_bound(_sql_stmt, _name, _operator, _cut_id):
        """Return ``_sql_stmt`` with ``id <_operator> <_cut_id>`` added to its WHERE clause.

        ``_sql_stmt`` must start with ``SELECT * FROM <_name>`` (as built by
        get_select_sql_stmt), optionally followed by `` WHERE ...``. A WHERE
        clause is created when the statement has none.
        """
        head = ''.join(['SELECT * FROM ', _name])
        if not _sql_stmt.startswith(head):
            raise ValueError(''.join(['_sql_stmt must start with ', head]))

        rest = _sql_stmt[len(head):]
        bound = ''.join(['id ', _operator, ' ', str(_cut_id)])
        if rest.startswith(' WHERE '):
            return ''.join([head, ' WHERE ', bound, ' AND ', rest[len(' WHERE '):]])
        return ''.join([head, ' WHERE ', bound, rest])

    @staticmethod
    def _db_error(_sql_stmt, err):
        """Log a failed statement and build the ji.PreviewingError (state 50050) to raise."""
        ret = dict()
        ret['state'] = ji.Common.exchange_state(50050)
        ret['state']['sub']['zh-cn'] = ''.join([ret['state']['sub']['zh-cn'], ': ', err._full_msg])
        logger.error(_sql_stmt)
        logger.error(err)
        return ji.PreviewingError(json.dumps(ret))

    @staticmethod
    def _execute_and_commit(_cnx, _sql_stmt):
        """Execute a write statement on ``_cnx`` and commit.

        The cursor and ``_cnx`` are always closed, even if creating the cursor
        fails. A mysql.connector.Error is logged and re-raised as
        ji.PreviewingError.
        """
        try:
            cursor = _cnx.cursor(dictionary=True, buffered=False)
            try:
                cursor.execute(_sql_stmt)
                _cnx.commit()
            except mysql.connector.Error as err:
                raise RecodeOperate._db_error(_sql_stmt, err)
            finally:
                cursor.close()
        finally:
            _cnx.close()

    @staticmethod
    def sql_stmt_fetchall(_cnx=None, _sql_stmt=None):
        """Run a query on ``_cnx`` and return all rows as dicts; ``_cnx`` is closed afterwards.

        Raises:
            ValueError: ``_cnx`` or ``_sql_stmt`` is None.
            ji.PreviewingError: state 50050 on a MySQL error.
        """
        if _cnx is None or _sql_stmt is None:
            raise ValueError('_cnx and _sql_stmt must not None')

        try:
            cursor = _cnx.cursor(dictionary=True, buffered=False)
            try:
                cursor.execute(_sql_stmt)
                return cursor.fetchall()
            except mysql.connector.Error as err:
                raise RecodeOperate._db_error(_sql_stmt, err)
            finally:
                cursor.close()
        finally:
            _cnx.close()

    @staticmethod
    def get_fit_statement(object_rules, field, value):
        """Render ``value`` as an SQL literal according to ``object_rules[field][0]``.

        String-typed fields are passed through regex_sql_str, stripped of
        surrounding double quotes, escaped and double-quoted; (int, long) fields
        are returned unchanged.

        Raises:
            TypeError: the rule's type is neither basestring nor (int, long).
        """
        ele_type = object_rules[field][0]
        if ele_type == basestring:
            # NOTE(review): str() on a non-ASCII unicode value raises
            # UnicodeEncodeError under Python 2's default encoding -- confirm
            # how callers pass text.
            _s = regex_sql_str.sub('"', str(value)).strip('"')
            # Escape backslashes before quotes; otherwise a trailing "\" would
            # escape the closing quote (MySQL treats "\" as an escape character
            # in string literals by default).
            _s = _s.replace('\\', '\\\\').replace('"', '\\"')
            return ''.join(['"', _s, '"'])
        elif ele_type == (int, long):
            # NOTE(review): returned as-is; for ``id`` (not checked by
            # previewing_has_kv) confirm callers always pass an integer.
            return value
        else:
            raise TypeError(''.join(['unknown type ', str(ele_type)]))

